Skip to content

Conversation

@vzakhari
Copy link
Contributor

@vzakhari vzakhari commented Dec 6, 2024

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.
@vzakhari vzakhari requested a review from jeanPerier December 6, 2024 02:10
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 6, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.


Full diff: https://github.com/llvm/llvm-project/pull/118911.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+13-3)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+18)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0c34c8221aeda6..ace63a970db932 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
     int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
-    assert(dimVal > 0 && "DIM must be present and a positive constant");
     mlir::Value resultShape, dimExtent;
     std::tie(resultShape, dimExtent) =
         genResultShape(loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     mlir::Value inShape = hlfir::genShape(loc, builder, array);
     llvm::SmallVector<mlir::Value> inExtents =
         hlfir::getExplicitExtentsFromShape(inShape, builder);
+    assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
+           "DIM must be present and a positive constant not exceeding "
+           "the array's rank");
     if (inShape.getUses().empty())
       inShape.getDefiningOp()->erase();
 
@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
     // would avoid creating a temporary for the elemental array expression.
     target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
       if (mlir::Value dim = sum.getDim()) {
-        if (fir::getIntIfConstant(dim)) {
+        if (auto dimVal = fir::getIntIfConstant(dim)) {
           if (!fir::isa_trivial(sum.getType())) {
             // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
             // It is only legal when X is 1, and it should probably be
             // canonicalized into SUM(a).
-            return false;
+            fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+                hlfir::getFortranElementOrSequenceType(
+                    sum.getArray().getType()));
+            if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+              // Ignore SUMs with illegal DIM values.
+              // They may appear in dead code,
+              // and they do not have to be converted.
+              return false;
+            }
           }
         }
       }
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 703b6673154f3f..313e54d5d0c4af 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
 // CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
 // CHECK:           return
 // CHECK:         }
+
+// negative: invalid dim==0
+func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 0 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_invalid_dim0(
+// CHECK:           hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+
+// negative: invalid dim>rank
+func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 3 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_invalid_dim_big(
+// CHECK:           hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@vzakhari vzakhari merged commit 084451c into llvm:main Dec 9, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants